import time

import torch
from torch.utils.data import DataLoader
from torchvision.models.detection import (
    keypointrcnn_resnet50_fpn,
    KeypointRCNN_ResNet50_FPN_Weights,
)
from yolov3.yolo import YOLOv3

from image_dataset import ImageDataset


def inference_pipeline(
    fp_input: str,
    fp_output: str,
    model_type: str = "yolo_v3",
    img_size: int = 416,
    batch_size: int = 16,
    n_workers: int = 2,
) -> None:
    """Run a detection model over an image dataset and log one CSV row per image.

    The output file starts with the header ``qry_id,img_rank,<pred_col>`` and
    is followed by one line per processed image. A progress message is printed
    roughly every 10,000 processed images.

    Args:
        fp_input: Input location, forwarded unchanged to ``ImageDataset``.
        fp_output: Path of the CSV file to write. An existing file is
            overwritten.
        model_type: Which detector to run.
            ``"mask_rcnn"`` loads torchvision's ``keypointrcnn_resnet50_fpn``
            with its default weights (note: a Keypoint R-CNN, despite the
            option name) and logs the highest detection score per image in
            the ``max_score`` column.
            ``"yolo_v3"`` builds ``YOLOv3`` with ``person_detector=True`` and
            logs the number of returned scores per image in the ``n_people``
            column.
        img_size: Image size forwarded to ``YOLOv3`` and ``ImageDataset``.
        batch_size: Number of images per ``DataLoader`` batch.
        n_workers: Number of ``DataLoader`` worker processes.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    report_every = 10_000  # print a progress line about every this many images

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if model_type == "mask_rcnn":
        weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
        model = keypointrcnn_resnet50_fpn(weights=weights).to(device)
        model = model.eval()
        pred_col = "max_score"
    elif model_type == "yolo_v3":
        model = YOLOv3(
            device=device, person_detector=True, return_dict=True, img_size=img_size
        )
        pred_col = "n_people"
    else:
        raise ValueError("'model_type' must be one of 'mask_rcnn' or 'yolo_v3'")

    # Spread batches over all visible GPUs when more than one is available.
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    dataset = ImageDataset(fp_input, model_type=model_type, img_size=img_size)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers)

    # The context manager guarantees the log is closed even if inference fails
    # part-way through the dataset.
    with open(fp_output, "w", encoding="utf-8") as log_file:
        log_file.write(f"qry_id,img_rank,{pred_col}\n")

        n_done = 0  # images actually processed so far
        n_reported = 0  # value of n_done at the last progress message
        st = time.time()
        for qry_id, img_rank, img_data in loader:

            with torch.no_grad():
                outputs = model(img_data.to(device))

            if model_type == "mask_rcnn":
                preds = []
                for out in outputs:
                    scores = out["scores"].cpu().numpy()
                    # An image without any detection has an empty score
                    # array; log 0.0 for it instead of crashing the run.
                    preds.append(scores.max() if scores.size else 0.0)
            else:
                preds = [len(out["scores"]) for out in outputs]

            log_file.writelines(
                f"{q},{r},{p}\n"
                for q, r, p in zip(qry_id.numpy(), img_rank.numpy(), preds)
            )
            # One flush per batch keeps the log tail-able without paying for a
            # flush on every single row.
            log_file.flush()

            # Count real images rather than assuming full batches, so the
            # final (possibly short) batch and batch sizes that do not divide
            # the reporting interval are handled correctly.
            n_done += len(preds)
            if n_done - n_reported >= report_every:
                print(
                    f"Processing {n_done - n_reported:_} images took "
                    f"{time.time() - st}s. {n_done} images have been processed"
                )
                n_reported = n_done
                st = time.time()


if __name__ == "__main__":
    # Imported locally: only needed when the file is executed as a script.
    import argparse

    # inference_pipeline requires an input and an output path, so they must be
    # supplied on the command line; calling it with no arguments raises
    # TypeError. Optional flags mirror the function's own defaults.
    parser = argparse.ArgumentParser(
        description="Run a detection model over an image dataset and write "
        "per-image predictions to a CSV file."
    )
    parser.add_argument("fp_input", help="input location passed to ImageDataset")
    parser.add_argument("fp_output", help="path of the CSV file to write")
    parser.add_argument(
        "--model-type",
        choices=("mask_rcnn", "yolo_v3"),
        default="yolo_v3",
        help="detector to run (default: %(default)s)",
    )
    parser.add_argument("--img-size", type=int, default=416)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--n-workers", type=int, default=2)
    args = parser.parse_args()

    inference_pipeline(
        args.fp_input,
        args.fp_output,
        model_type=args.model_type,
        img_size=args.img_size,
        batch_size=args.batch_size,
        n_workers=args.n_workers,
    )
